import torch
from torchvision.models import mobilenet_v2
from torch import nn


# Reference implementation: https://github.com/Lee-Gihun/FedNTD/blob/master/train_tools/models/fedavgnet.py

import torch
from torch import nn

class FedNet(nn.Module):
    """Two-stage CNN classifier (FedAvg-style CNN).

    Layer stack::

        conv5x5(32) -> ReLU -> maxpool(2) -> conv5x5(64) -> ReLU -> maxpool(2)
        -> flatten -> linear(512) -> ReLU -> linear(num_classes)

    The input width of ``linear_1`` is measured at construction time. A zero
    tensor of shape ``(1, in_channels, H, W)`` is pushed through the
    convolutional stack. As a result, the model only accepts inputs with that
    channel count and spatial size at run time.

    Args:
        num_classes: Number of logits produced by ``classifier``.
        in_channels: Channels of the input images. Defaults to 3, the value
            the original implementation hard-coded. Use 1 for grayscale data.
        input_size: Spatial size of the input images. Pass an ``int`` for
            square images or an ``(H, W)`` pair. Defaults to 28, the size the
            original implementation assumed (MedMNIST 28x28).
    """

    def __init__(self, num_classes, in_channels=3, input_size=28):
        super().__init__()
        self.conv2d_1 = nn.Conv2d(in_channels, 32, kernel_size=5, padding=2)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.conv2d_2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()

        if isinstance(input_size, (tuple, list)):
            height, width = input_size
        else:
            height = width = int(input_size)

        # Measure the flattened feature size of the conv stack.
        # - zeros instead of randn: the probe must not consume the global RNG,
        #   which would shift the initialisation of the linear layers below.
        # - no_grad: no autograd graph is needed for a shape probe.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, height, width)
            num_features = self._forward_conv_layers(probe).numel()

        # Registered after the conv layers, matching the original order, so
        # state_dict keys and parameter order are unchanged.
        self.linear_1 = nn.Linear(num_features, 512)
        self.classifier = nn.Linear(512, num_classes)

    def _forward_conv_layers(self, x):
        """Apply both conv -> ReLU -> max-pool stages and return the feature map.

        Each 5x5 conv with padding=2 keeps the spatial size. Each max-pool
        halves it, rounding down. An ``(N, C, H, W)`` input therefore yields
        ``(N, 64, H // 2 // 2, W // 2 // 2)``.
        """
        x = self.conv2d_1(x)
        x = self.relu(x)
        x = self.max_pooling(x)
        x = self.conv2d_2(x)
        x = self.relu(x)
        x = self.max_pooling(x)
        return x

    def forward(self, x):
        """Return unnormalised class logits of shape ``(N, num_classes)``.

        ``x`` must have the channel count and spatial size given at
        construction time. Otherwise ``linear_1`` receives the wrong number
        of features.
        """
        x = self._forward_conv_layers(x)
        x = self.flatten(x)
        x = self.relu(self.linear_1(x))
        x = self.classifier(x)
        return x
